"""Add a RabbitMQ server to the mix."""

import base64
import contextlib
import datetime
import os
import pathlib
import ssl
import tempfile
import typing
from unittest import mock

from cryptography import x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
import pika

from . import cluster
from .. import misc
from ..logger import get_logger
from ..session import get_session

# Module-wide logger and session (the session is used below for HTTP calls),
# both named after this module.
LOGGER, SESSION = get_logger(__name__), get_session(__name__)


class RabbitMQServer(cluster.KubernetesCluster):
    """Add a RabbitMQ server to the mix.

    During class setup, a RabbitMQ broker (``rabbitmq:3-management`` image) is
    deployed into the ``rabbitmq`` namespace of the Kubernetes cluster. It gets
    TLS listeners for AMQP (5671), the management API (15671) and STOMP
    (61614). All of them use certificates from a throw-away CA generated on the
    fly. The ``cki.exchange.*`` exchanges and a message TTL policy for the
    retry queues are created. The connection details are exported through
    ``RABBITMQ_*``/``STOMP_*`` environment variables while the class context
    is active.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Set up the cluster, then start RabbitMQ for the whole test class."""
        super().setUpClass()
        # enterClassContext() registers the context's exit as a class cleanup.
        cls.enterClassContext(cls._rabbitmq())

    @classmethod
    @contextlib.contextmanager
    def _rabbitmq(cls) -> typing.Iterator[None]:
        """Deploy RabbitMQ, prepare it for the tests and expose it via env vars.

        Yields once all of the following hold: the deployment reports as
        started, the CKI exchanges and the retry-queue TTL policy exist, and
        the connection details are set in ``os.environ``. On exit, the
        environment is restored and the temporary certificate files are
        removed.

        Raises:
            RuntimeError: if the RabbitMQ deployment does not come up.
        """
        service_id = 'rabbitmq'
        now = misc.now_tz_utc()
        with cls.k8s_namespace(service_id):
            LOGGER.info('Starting RabbitMQ')
            secret_data = cls._crypto_setup(now)

            # All keys and certificates go into one Secret, mounted at /secrets below.
            cls.k8s_apply(namespace=service_id, body={
                'apiVersion': 'v1', 'kind': 'Secret',
                'metadata': {'name': service_id},
                'data': {k: base64.b64encode(v).decode('ascii') for k, v in secret_data.items()},
            })
            # ssl.conf enables the TLS listeners (AMQP 5671, management 15671, STOMP
            # 61614) and certificate-based EXTERNAL login, which takes the user name
            # from the client certificate's common name.
            cls.k8s_apply(namespace=service_id, body={
                'apiVersion': 'v1', 'kind': 'ConfigMap',
                'metadata': {'name': service_id},
                'data': {
                    'ssl.conf':
                        'listeners.ssl.default     = 5671\n'
                        'ssl_options.cacertfile    = /secrets/ca_certificate.pem\n'
                        'ssl_options.certfile      = /secrets/server_certificate.pem\n'
                        'ssl_options.keyfile       = /secrets/server_private_key.pem\n'
                        'ssl_options.versions.1    = tlsv1.2\n'
                        'ssl_options.verify        = verify_peer\n'
                        'ssl_options.fail_if_no_peer_cert = true\n'
                        'ssl_cert_login_from       = common_name\n'
                        'auth_mechanisms.1         = PLAIN\n'
                        'auth_mechanisms.2         = EXTERNAL\n'
                        'management.tcp.port       = 15672\n'
                        'management.ssl.port       = 15671\n'
                        'management.ssl.cacertfile = /secrets/ca_certificate.pem\n'
                        'management.ssl.certfile   = /secrets/server_certificate.pem\n'
                        'management.ssl.keyfile    = /secrets/server_private_key.pem\n'
                        'stomp.listeners.ssl.1     = 61614\n'
                        'stomp.ssl_cert_login      = true\n',
                    'enabled_plugins':
                        '[rabbitmq_management,rabbitmq_stomp,rabbitmq_auth_mechanism_ssl].'
                }})
            cls.k8s_deployment(namespace=service_id, name=service_id, setup_at=now, container={
                'image': 'rabbitmq:3-management',
                'startupProbe': cls.k8s_startup_probe(5671),  # last listener to start
                'volumeMounts': [
                    {'name': 'secrets', 'mountPath': '/secrets', 'readOnly': True},
                    {'name': 'config', 'mountPath': '/etc/rabbitmq/conf.d/ssl.conf',
                     'subPath': 'ssl.conf', 'readOnly': True},
                    {'name': 'config', 'mountPath': '/etc/rabbitmq/enabled_plugins',
                     'subPath': 'enabled_plugins'},
                ],
            }, volumes=[
                {'name': 'secrets', 'secret': {'secretName': service_id}},
                {'name': 'config', 'configMap': {'name': service_id}},
            ])
            cls.k8s_service(namespace=service_id, name=service_id)
            if not cls.k8s_wait(namespace=service_id, name=service_id, setup_at=now):
                raise RuntimeError(f'{service_id} did not start up')

            with tempfile.TemporaryDirectory() as directory:
                ca_certificate = pathlib.Path(directory, 'ca_certificate.pem')
                ca_certificate.write_bytes(secret_data['ca_certificate.pem'])
                # Key and certificate in one PEM file, so that load_cert_chain()
                # and the STOMP/RabbitMQ clients need no separate key file.
                client_combined = pathlib.Path(directory, 'client_combined.pem')
                client_combined.write_bytes(
                    secret_data['client_private_key.pem'] + secret_data['client_certificate.pem'])

                ssl_context = ssl.create_default_context(cafile=ca_certificate)
                ssl_context.load_cert_chain(client_combined)
                # Authenticate with the client certificate (EXTERNAL) instead of a
                # password. The context manager closes the connection even if
                # declaring an exchange fails.
                with pika.BlockingConnection([pika.ConnectionParameters(
                    host=cls.hostname, port=5671, virtual_host='/',
                    credentials=pika.credentials.ExternalCredentials(),
                    ssl_options=pika.SSLOptions(ssl_context),
                )]) as connection:
                    channel = connection.channel()
                    for exchange in ('webhooks', 'retry.incoming', 'retry.outgoing'):
                        channel.exchange_declare(f'cki.exchange.{exchange}')

                # Messages in queues matching cki.queue.retry.* expire after 1000 ms
                # (presumably to keep retry-related tests fast).
                # NOTE(review): this targets the plain-HTTP management port 15672, so
                # ``verify`` most likely has no effect; port 15671 serves HTTPS.
                SESSION.put(
                    f'http://{cls.hostname}:15672/api/policies/%2F/tests-ttl',
                    verify=ca_certificate,
                    auth=('guest', 'guest'),
                    json={
                        'vhost': '/',
                        'name': 'tests-ttl',
                        'pattern': r'cki\.queue\.retry\..*',
                        'apply-to': 'queues',
                        'definition': {'message-ttl': 1000},
                    },
                ).raise_for_status()

                with mock.patch.dict(os.environ, {
                    'RABBITMQ_HOST': cls.hostname,
                    'RABBITMQ_PORT': '5671',
                    'RABBITMQ_MANAGEMENT_PORT': '15671',
                    'RABBITMQ_VIRTUAL_HOST': '/',
                    'RABBITMQ_USER': 'guest',
                    'RABBITMQ_PASSWORD': 'guest',
                    'RABBITMQ_CAFILE': str(ca_certificate),
                    'RABBITMQ_CERTFILE': str(client_combined),
                    'STOMP_HOST': cls.hostname,
                    'STOMP_PORT': '61614',
                    'STOMP_CERTFILE': str(client_combined),
                }):
                    yield

    @classmethod
    def _crypto_setup(cls, now: datetime.datetime) -> dict[str, bytes]:
        """Generate a throw-away CA plus server and client keys and certificates.

        Args:
            now: start of the validity period of all generated certificates.

        Returns:
            PEM data keyed by file name:
            ``{ca,server,client}_private_key.pem`` (unencrypted PKCS#8) and
            ``{ca,server,client}_certificate.pem``. The server certificate is
            valid for ``cls.hostname`` and ``rabbitmq.rabbitmq.svc.cluster.local``.
            The client certificate has the common name ``guest``.
        """
        keys = {
            'ca': rsa.generate_private_key(65537, 2048),
            'server': rsa.generate_private_key(65537, 2048),
            'client': rsa.generate_private_key(65537, 2048),
        }
        ca_name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, 'RabbitMQ CA')])
        # ~1000 years: the certificates effectively never expire.
        not_valid_after = now + datetime.timedelta(days=365000)
        # Every certificate is signed by the CA key, so all of them share this AKI.
        authority_key_identifier = x509.AuthorityKeyIdentifier.from_issuer_public_key(
            keys['ca'].public_key())

        certs = {
            'ca': (
                x509.CertificateBuilder()
                .subject_name(ca_name)
                .issuer_name(ca_name)
                .public_key(keys['ca'].public_key())
                .serial_number(x509.random_serial_number())
                .not_valid_before(now)
                .not_valid_after(not_valid_after)
                .add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
                .add_extension(x509.KeyUsage(
                    digital_signature=False, content_commitment=False, key_encipherment=False,
                    data_encipherment=False, key_agreement=False, key_cert_sign=True,
                    crl_sign=True, encipher_only=False, decipher_only=False,
                ), critical=False)
                .add_extension(
                    x509.SubjectKeyIdentifier.from_public_key(keys['ca'].public_key()),
                    critical=False,
                )
                .add_extension(authority_key_identifier, critical=False)
            ).sign(keys['ca'], hashes.SHA256()),
            'server': (
                x509.CertificateBuilder()
                .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, 'host')]))
                .issuer_name(ca_name)
                .public_key(keys['server'].public_key())
                .serial_number(x509.random_serial_number())
                .not_valid_before(now)
                .not_valid_after(not_valid_after)
                .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=False)
                .add_extension(x509.KeyUsage(
                    digital_signature=True, content_commitment=False, key_encipherment=True,
                    data_encipherment=False, key_agreement=False, key_cert_sign=False,
                    crl_sign=False, encipher_only=False, decipher_only=False,
                ), critical=False)
                .add_extension(x509.ExtendedKeyUsage([
                    x509.oid.ExtendedKeyUsageOID.SERVER_AUTH,
                ]), critical=False)
                .add_extension(
                    # The SKI identifies this certificate's own (server) key, not the
                    # CA's; the CA key is referenced via the AKI instead.
                    x509.SubjectKeyIdentifier.from_public_key(keys['server'].public_key()),
                    critical=False,
                )
                .add_extension(authority_key_identifier, critical=False)
                .add_extension(x509.SubjectAlternativeName([
                    x509.DNSName(cls.hostname),
                    x509.DNSName('rabbitmq.rabbitmq.svc.cluster.local'),
                ]), critical=False)
            ).sign(keys['ca'], hashes.SHA256()),
            'client': (
                x509.CertificateBuilder()
                .subject_name(x509.Name([
                    # ssl_cert_login_from = common_name: the CN is the RabbitMQ user.
                    x509.NameAttribute(NameOID.COMMON_NAME, 'guest'),
                    x509.NameAttribute(NameOID.ORGANIZATION_NAME, 'client'),
                ]))
                .issuer_name(ca_name)
                .public_key(keys['client'].public_key())
                .serial_number(x509.random_serial_number())
                .not_valid_before(now)
                .not_valid_after(not_valid_after)
                .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=False)
                .add_extension(x509.KeyUsage(
                    digital_signature=True, content_commitment=False, key_encipherment=True,
                    data_encipherment=False, key_agreement=False, key_cert_sign=False,
                    crl_sign=False, encipher_only=False, decipher_only=False,
                ), critical=False)
                .add_extension(x509.ExtendedKeyUsage([
                    x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH,
                ]), critical=False)
                .add_extension(
                    x509.SubjectKeyIdentifier.from_public_key(keys['client'].public_key()),
                    critical=False,
                )
                .add_extension(authority_key_identifier, critical=False)
            ).sign(keys['ca'], hashes.SHA256()),
        }

        return {
            **{f'{n}_private_key.pem': k.private_bytes(
                serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8,
                serialization.NoEncryption()) for n, k in keys.items()},
            **{f'{n}_certificate.pem': c.public_bytes(serialization.Encoding.PEM)
               for n, c in certs.items()},
        }
